{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import evaluate\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "from transformers import (\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorWithPadding,\n",
    ")\n",
    "from dataclasses import dataclass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'\n",
    "os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class BertCLS:\n",
    "    def __init__(\n",
    "        self,\n",
    "        model,\n",
    "        tokenizer,\n",
    "        train_dataset=None,\n",
    "        eval_dataset=None,\n",
    "        output_dir=\"output\",\n",
    "        epoch=3,\n",
    "    ):\n",
    "        self.model = model\n",
    "        self.tokenizer = tokenizer\n",
    "        self.train_dataset = train_dataset\n",
    "        self.eval_dataset = eval_dataset\n",
    "\n",
    "        self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
    "\n",
    "        self.args = self.get_args(output_dir, epoch)\n",
    "\n",
    "        self.trainer = Trainer(\n",
    "            model=self.model,\n",
    "            args=self.args,\n",
    "            train_dataset=self.train_dataset,\n",
    "            eval_dataset=self.eval_dataset,\n",
    "            data_collator=self.data_collator,\n",
    "            # compute_metrics=compute_metrics,\n",
    "            tokenizer=tokenizer,\n",
    "        )\n",
    "\n",
    "    def get_args(self, output_dir, epoch):\n",
    "        if self.eval_dataset:\n",
    "            args = TrainingArguments(\n",
    "                output_dir=output_dir,\n",
    "                evaluation_strategy=\"epoch\",\n",
    "                save_strategy=\"epoch\",\n",
    "                save_total_limit=3,\n",
    "                learning_rate=2e-5,\n",
    "                num_train_epochs=epoch,\n",
    "                weight_decay=0.01,\n",
    "                per_device_train_batch_size=32,\n",
    "                per_device_eval_batch_size=16,\n",
    "                # logging_steps=16,\n",
    "                save_safetensors=True,\n",
    "                overwrite_output_dir=True,\n",
    "                load_best_model_at_end=True,\n",
    "            )\n",
    "        else:\n",
    "            args = TrainingArguments(\n",
    "                output_dir=output_dir,\n",
    "                evaluation_strategy=\"no\",\n",
    "                save_strategy=\"epoch\",\n",
    "                save_total_limit=3,\n",
    "                learning_rate=2e-5,\n",
    "                num_train_epochs=epoch,\n",
    "                weight_decay=0.01,\n",
    "                per_device_train_batch_size=32,\n",
    "                per_device_eval_batch_size=16,\n",
    "                # logging_steps=16,\n",
    "                save_safetensors=True,\n",
    "                overwrite_output_dir=True,\n",
    "                # load_best_model_at_end=True,\n",
    "            )\n",
    "        return args\n",
    "\n",
    "    def set_args(self, args):\n",
    "        \"\"\"\n",
    "        从外部重新设置 TrainingArguments，args 更新后，trainer也进行更新\n",
    "        \"\"\"\n",
    "        self.args = args\n",
    "\n",
    "        self.trainer = Trainer(\n",
    "            model=self.model,\n",
    "            args=self.args,\n",
    "            train_dataset=self.train_dataset,\n",
    "            eval_dataset=self.eval_dataset,\n",
    "            data_collator=self.data_collator,\n",
    "            # compute_metrics=compute_metrics,\n",
    "            tokenizer=self.tokenizer,\n",
    "        )\n",
    "\n",
    "    def train(self, epoch=None, over_write=False):\n",
    "        if epoch:\n",
    "            self.args.num_train_epochs = epoch\n",
    "            \n",
    "        best_model_path = os.path.join(self.args.output_dir, \"best_model\")\n",
    "\n",
    "        if over_write or not os.path.exists(best_model_path):\n",
    "            self.trainer.train()\n",
    "            self.trainer.save_model(best_model_path)\n",
    "        else:\n",
    "            print(\n",
    "                f\"预训练权重 {best_model_path} 已存在，且over_write={over_write}。不启动模型训练！\"\n",
    "            )\n",
    "\n",
    "    def eval(self, eval_dataset):\n",
    "        predictions = self.trainer.predict(eval_dataset)\n",
    "        preds = np.argmax(predictions.predictions, axis=-1)\n",
    "        metric = evaluate.load(\"glue\", \"mrpc\")\n",
    "        return metric.compute(predictions=preds, references=predictions.label_ids)\n",
    "\n",
    "    def pred(self, pred_dataset):\n",
    "        predictions = self.trainer.predict(pred_dataset)\n",
    "        preds = np.argmax(predictions.predictions, axis=-1)\n",
    "        return pred_dataset.add_column(\"pred\", preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"bert-base-chinese\"\n",
    "\n",
    "bert = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_name,\n",
    "    trust_remote_code=True,\n",
    "    num_labels=2,\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset(\"lansinuote/ChnSentiCorp\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': '选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',\n",
       " 'label': 1}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encoded_dataset = dataset.map(\n",
    "    # preprocess_data, \n",
    "    # batched=True, \n",
    "    # remove_columns=dataset['train'].column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_func_pad(item):\n",
    "    global tokenizer\n",
    "    tokenized_inputs = tokenizer(\n",
    "        item[\"text\"],\n",
    "        max_length=512,\n",
    "        truncation=True,\n",
    "        padding=\"max_length\"\n",
    "    )\n",
    "    return tokenized_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_func(item):\n",
    "    global tokenizer\n",
    "    tokenized_inputs = tokenizer(\n",
    "        item[\"text\"],\n",
    "        max_length=512,\n",
    "        truncation=True,\n",
    "    )\n",
    "    return tokenized_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9cc72618e5c8421fa73a73f062affeb2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/9600 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "batch_dataset_pad = ds[\"train\"].map(\n",
    "    tokenize_func_pad, remove_columns=[\"text\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_dataset_pad.set_format(\"torch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([9600, 512])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_dataset_pad[\"input_ids\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da9775d04e14420b958d49084be5935c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/9600 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "batch_dataset = ds[\"train\"].map(tokenize_func, remove_columns=[\"text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_dataset[\"input_ids\"][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_dataset.set_format(\"torch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'token_type_ids', 'attention_mask'],\n",
       "    num_rows: 9600\n",
       "})"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 126])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = [batch_dataset[i] for i in range(16)]\n",
    "data_collator(data)[\"input_ids\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'token_type_ids', 'attention_mask'],\n",
       "    num_rows: 9600\n",
       "})"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_dataset_pad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='300' max='300' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [300/300 03:09, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "BertCLS(bert, tokenizer, batch_dataset_pad, epoch=1).train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "105"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(batch_dataset[0][\"input_ids\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jie/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='300' max='300' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [300/300 02:19, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "BertCLS(bert, tokenizer, batch_dataset, epoch=1).train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
